{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp problem_types.premask_mlm\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "# Set before TensorFlow is imported in the next cell.\n",
    "# NOTE(review): a value of -1 presumably hides all GPUs so the notebook runs on CPU -- confirm.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase, test_top_layer\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "\n",
    "# shared fixtures used by the test cells of this notebook\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "# take one batch of numpy features from the training input pipeline\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PreMasked-MLM(premask_mlm)\n",
    "\n",
    "This module includes the necessary parts to register the pre-masked MLM problem type.\n",
    "\n",
    "The difference between this problem type and `masklm` is that this problem type requires users to mask their input themselves, which gives extra flexibility.\n",
    "\n",
    "It is recommended to use this problem type rather than `masklm`.\n",
    "\n",
    "**IMPORTANT**\n",
    "\n",
    "To use `premask_mlm`, you need to make sure that\n",
    "\n",
    "1. Your input is masked with `[MASK]`\n",
    "2. Your label is a list of strings with the same length as the number of `[MASK]` tokens in your input\n",
    "\n",
    "Please note that we only support masking text.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from loguru import logger\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling, pad_to_shape)\n",
    "from m3tl.special_tokens import PREDICT\n",
    "from m3tl.utils import gather_indexes, get_phase, load_transformer_tokenizer\n",
    "from transformers import TFSharedEmbeddings\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class PreMaskMLM(tf.keras.Model):\n",
    "    \"\"\"Top layer of the `premask_mlm` problem type.\n",
    "\n",
    "    Projects sequence hidden states onto the vocabulary. Outside of the PREDICT\n",
    "    phase only the hidden states at `{problem_name}_masked_lm_positions` are\n",
    "    projected, and a loss against `{problem_name}_masked_lm_ids` is registered\n",
    "    with `add_loss`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.Tensor=None, share_embedding=False) -> None:\n",
    "        \"\"\"Create the output projection.\n",
    "\n",
    "        Args:\n",
    "            params: Parameter object. `bert_config.vocab_size`,\n",
    "                `bert_config.hidden_size` and `share_embedding` are read from it.\n",
    "            problem_name: Used as the model name and as the prefix of the\n",
    "                feature keys read in `call`.\n",
    "            input_embeddings: Embedding matrix whose first axis is taken as the\n",
    "                vocab size and whose last axis is taken as the embedding size.\n",
    "                Required unless `share_embedding` is False.\n",
    "            share_embedding: When exactly False, a fresh Dense projection of\n",
    "                size `bert_config.vocab_size` is used. Otherwise the weights of\n",
    "                `input_embeddings` are reused, provided that\n",
    "                `params.share_embedding` is set and the embedding size equals\n",
    "                `bert_config.hidden_size`; if not, a Dense projection sized by\n",
    "                the first axis of `input_embeddings` is used.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `share_embedding` is not False and\n",
    "                `input_embeddings` is None.\n",
    "        \"\"\"\n",
    "        super(PreMaskMLM, self).__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "\n",
    "        # same as masklm\n",
    "        if share_embedding is False:\n",
    "            self.vocab_size = self.params.bert_config.vocab_size\n",
    "            self.share_embedding = False\n",
    "        else:\n",
    "            if input_embeddings is None:\n",
    "                # fail early with a clear message instead of an AttributeError on `.shape`\n",
    "                raise ValueError(\n",
    "                    'input_embeddings must be provided when share_embedding is enabled')\n",
    "            self.vocab_size = input_embeddings.shape[0]\n",
    "            embedding_size = input_embeddings.shape[-1]\n",
    "            share_valid = (self.params.bert_config.hidden_size ==\n",
    "                           embedding_size)\n",
    "            if not share_valid and self.params.share_embedding:\n",
    "                logger.warning(\n",
    "                    'Share embedding is enabled but hidden_size != embedding_size')\n",
    "            # logical `and` (not bitwise `&`), normalised to a plain bool\n",
    "            self.share_embedding = bool(\n",
    "                self.params.share_embedding and share_valid)\n",
    "\n",
    "        if self.share_embedding:\n",
    "            # reuse the input embedding matrix as the output projection\n",
    "            self.share_embedding_layer = TFSharedEmbeddings(\n",
    "                vocab_size=self.vocab_size, hidden_size=embedding_size)\n",
    "            self.share_embedding_layer.build([1])\n",
    "            self.share_embedding_layer.weight = input_embeddings\n",
    "        else:\n",
    "            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Compute vocabulary probabilities and, when not predicting, the MLM loss.\n",
    "\n",
    "        Args:\n",
    "            inputs: A `(features, hidden_features)` pair. `hidden_features['seq']`\n",
    "                holds the sequence hidden states; outside of the PREDICT phase\n",
    "                `features` must contain `{problem_name}_masked_lm_positions`\n",
    "                and `{problem_name}_masked_lm_ids`.\n",
    "\n",
    "        Returns:\n",
    "            Softmax over the vocabulary logits, computed for the gathered\n",
    "            masked positions outside of the PREDICT phase and for the full\n",
    "            `seq` hidden states otherwise.\n",
    "        \"\"\"\n",
    "        mode = get_phase()\n",
    "        features, hidden_features = inputs\n",
    "\n",
    "        # the input is already masked by the user; only the masked positions\n",
    "        # have to be picked out of the sequence output here\n",
    "        seq_hidden_feature = hidden_features['seq']\n",
    "        if mode != PREDICT:\n",
    "            positions = features['{}_masked_lm_positions'.format(self.problem_name)]\n",
    "            hidden_size = seq_hidden_feature.shape.as_list()[-1]\n",
    "\n",
    "            # gather_indexes will flatten the seq hidden_states, we need to reshape\n",
    "            # back to 3d tensor\n",
    "            input_tensor = gather_indexes(seq_hidden_feature, positions)\n",
    "            shape_list = tf.concat([tf.shape(positions), [hidden_size]], axis=0)\n",
    "            input_tensor = tf.reshape(input_tensor, shape=shape_list)\n",
    "            # set_shape to determine rank\n",
    "            input_tensor.set_shape([None, None, hidden_size])\n",
    "        else:\n",
    "            input_tensor = seq_hidden_feature\n",
    "\n",
    "        if self.share_embedding:\n",
    "            mlm_logits = self.share_embedding_layer(\n",
    "                input_tensor, mode='linear')\n",
    "        else:\n",
    "            mlm_logits = self.share_embedding_layer(input_tensor)\n",
    "\n",
    "        if mode != PREDICT:\n",
    "            mlm_labels = features['{}_masked_lm_ids'.format(self.problem_name)]\n",
    "            mlm_labels.set_shape([None, None])\n",
    "            # bring the labels to the logits' length along axis 1\n",
    "            mlm_labels = pad_to_shape(from_tensor=mlm_labels, to_tensor=mlm_logits, axis=1)\n",
    "            # compute loss\n",
    "            mlm_loss = empty_tensor_handling_loss(\n",
    "                mlm_labels,\n",
    "                mlm_logits,\n",
    "                tf.keras.losses.sparse_categorical_crossentropy\n",
    "            )\n",
    "            loss = nan_loss_handling(mlm_loss)\n",
    "            self.add_loss(loss)\n",
    "\n",
    "        return tf.nn.softmax(mlm_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the shared top-layer test helper against PreMaskMLM with one sample batch\n",
    "test_top_layer(\n",
    "    PreMaskMLM,\n",
    "    problem='weibo_premask_mlm',\n",
    "    params=params,\n",
    "    sample_features=one_batch,\n",
    "    hidden_dim=hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def premask_mlm_get_or_make_label_encoder_fn(params: BaseParams, problem, mode, label_list, *args, **kwargs):\n",
    "    \"\"\"Return the transformer tokenizer used as label encoder for `problem`.\n",
    "\n",
    "    The tokenizer configured by `params.transformer_tokenizer_name` and\n",
    "    `params.transformer_tokenizer_loading` is loaded, then\n",
    "    `params.bert_config.vocab_size` is stored as the problem's `num_classes`.\n",
    "    `mode`, `label_list` and any extra arguments are accepted but not used.\n",
    "    \"\"\"\n",
    "    tokenizer = load_transformer_tokenizer(\n",
    "        tokenizer_name=params.transformer_tokenizer_name,\n",
    "        load_module_name=params.transformer_tokenizer_loading)\n",
    "    vocab_size = params.bert_config.vocab_size\n",
    "    params.set_problem_info(\n",
    "        problem=problem, info_name='num_classes', info=vocab_size)\n",
    "    return tokenizer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handling function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def premask_mlm_label_handling_fn(target: str, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs) -> dict:\n",
    "    \"\"\"Build the masked-LM label features of one pre-masked text example.\n",
    "\n",
    "    `target` is tokenized as pre-split words without special tokens and\n",
    "    padded/truncated to 20 ids. The positions of the `[MASK]` token id inside\n",
    "    `kwargs['tokenized_inputs']['input_ids']` are collected and zero-padded or\n",
    "    truncated to the same length.\n",
    "\n",
    "    Required keyword arguments: `modal_name`, `modal_type`, `problem` and, for\n",
    "    text modals, `tokenized_inputs`. `label_encoder` and `decoding_length`\n",
    "    are accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        An empty dict when `modal_type` is not `'text'`. Otherwise a dict with\n",
    "        `{problem}_masked_lm_positions` (list of int) as well as\n",
    "        `{problem}_masked_lm_ids` and `{problem}_masked_lm_weights` (int32\n",
    "        arrays built from the tokenizer's `input_ids` and `attention_mask`).\n",
    "    \"\"\"\n",
    "    # looked up eagerly so that a missing key raises KeyError; the modal name\n",
    "    # itself is not used any further\n",
    "    modal_name = kwargs['modal_name']\n",
    "    modal_type = kwargs['modal_type']\n",
    "    problem = kwargs['problem']\n",
    "    max_predictions_per_seq = 20\n",
    "\n",
    "    # only text modals can be masked\n",
    "    if modal_type != 'text':\n",
    "        return {}\n",
    "\n",
    "    tokenized_inputs = kwargs['tokenized_inputs']\n",
    "\n",
    "    # token ids of the labels, fixed to max_predictions_per_seq entries\n",
    "    encoded_target = tokenizer(\n",
    "        target,\n",
    "        truncation=True,\n",
    "        is_split_into_words=True,\n",
    "        padding='max_length',\n",
    "        max_length=max_predictions_per_seq,\n",
    "        return_special_tokens_mask=False,\n",
    "        add_special_tokens=False)\n",
    "\n",
    "    mask_encoding = tokenizer(\n",
    "        '[MASK]', add_special_tokens=False, is_split_into_words=False)\n",
    "    mask_token_id = mask_encoding['input_ids'][0]\n",
    "\n",
    "    mask_positions = [\n",
    "        position\n",
    "        for position, token_id in enumerate(tokenized_inputs['input_ids'])\n",
    "        if token_id == mask_token_id]\n",
    "    # zero-pad up to max_predictions_per_seq (a non-positive repeat count adds\n",
    "    # nothing), then cut off any surplus positions\n",
    "    padding_count = max_predictions_per_seq - len(mask_positions)\n",
    "    mask_positions = (\n",
    "        mask_positions + [0] * padding_count)[:max_predictions_per_seq]\n",
    "\n",
    "    label_ids = np.array(encoded_target['input_ids'], dtype='int32')\n",
    "    label_weights = np.array(encoded_target['attention_mask'], dtype='int32')\n",
    "\n",
    "    features = {}\n",
    "    features['{}_masked_lm_positions'.format(problem)] = mask_positions\n",
    "    features['{}_masked_lm_ids'.format(problem)] = label_ids\n",
    "    features['{}_masked_lm_weights'.format(problem)] = label_weights\n",
    "    return features\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
